from typing import Any, Dict, Optional, Union, List

import gymnasium as gym
import numpy as np
from omegaconf import DictConfig
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import wandb
from tianshou.data import Batch, ReplayBuffer, to_numpy
from tianshou.policy import DQNPolicy, DDPGPolicy, PPOPolicy
from Causal.ac_infer.Network.network_utils import pytorch_model
import tianshou as ts
import torch
import itertools
import copy

from collections import deque
import matplotlib.pyplot as plt
from Causal import Dynamics
from Policy.Reward.rew_term_trunc_manager import RewardTermTruncManager
from Policy.policy_utils import get_new_indices
from Policy.hindsight_filter import HindsightFilter
from Policy.fpg import FPGPolicy
from State.extractor import Extractor
from State.buffer import VectorGCReplayBufferManager


class GoalPolicy(ts.policy.BasePolicy):
    """
    Wraps around a base policy (which handles the RL components) \
    to return the appropriate updates and actions, updating after \
    changing the state for hindsight.
    """
    def __init__(
        self,
        policy: Union[DQNPolicy, DDPGPolicy],
        dynamics: Dynamics,
        rewtermdone: RewardTermTruncManager,
        extractor: Extractor,
        action_space: gym.Space,
        config: DictConfig,
        hindsight_filter: HindsightFilter,
    ):
        """
        Store the wrapped policy and its helpers, then set up the graph-history bookkeeping.

        policy: Union[DQNPolicy, DDPGPolicy],
            the wrapped tianshou policy that handles the RL components
        dynamics: Dynamics,
            dynamics model, used for estimating local dependency graph
        rewtermdone: RewardTermTruncManager,
            corresponding reward, termination and done objects to compute policy specific signals
        extractor: Extractor,
            handles factorization of state
        action_space: gym.Space,
            the environment action space, sometimes not stored in the TS policy so stored here
        config: DictConfig,
            configuration; this constructor reads policy.n_step, policy.graph_action_n,
            data.prio.prio, data.her.use_her, data.her.separate_her, num_factors and device
        hindsight_filter: HindsightFilter,
            its target_idx is used to derive self.factor_idx
        """
        super().__init__(action_scaling=True, action_space=action_space, action_bound_method="clip")

        # wrapped components
        self.policy = policy
        self.dynamics = dynamics
        self.rewtermdone = rewtermdone
        self.extractor = extractor
        self.action_space = action_space
        self.config = config
        self.hindsight_filter = hindsight_filter

        # learning / sampling options pulled from the config
        policy_cfg, data_cfg = config.policy, config.data
        self.n_step = policy_cfg.n_step
        self.use_prio = data_cfg.prio.prio
        self.use_her = data_cfg.her.use_her
        self.computed_her = data_cfg.her.separate_her

        # For extracting on-policy indices
        self.buffer_last_index = None
        self.on_policy_indices = None

        # graph encoding sizes: one-hot factor id followed by (num_factors + 1) parent slots
        num_factors = config.num_factors
        graph_action_n = policy_cfg.graph_action_n
        self.num_factors = num_factors
        self.num_edge_classes = 1 # we only support 0/1 edges, not more complex arrangements
        self.graph_size = num_factors + (num_factors + 1) * self.num_edge_classes
        self.graph_action_n = graph_action_n

        # when selecting a graph from the history buffer of seen graphs, mask out unused indices
        history_mask = torch.zeros(graph_action_n, dtype=torch.float32, device=config.device)
        history_mask[0] = 1
        self.choose_from_history_action_mask = history_mask

        # for converting factor index to one hot
        self.factor_onehot_helper = np.eye(num_factors, dtype=np.float32)
        # every 0/1 assignment over (num_factors + 1) parent slots, with the slot order flipped
        all_parent_sets = list(itertools.product([0, 1], repeat=num_factors + 1))
        self.count_idx_to_graph = np.flip(all_parent_sets, axis=-1).astype(bool)
        self.history_update_ready = False

        # hash graphs for counting purposes
        self.unique_graph_from_hash = {}
        self.unique_graph_from_id = np.zeros((graph_action_n, self.graph_size), dtype=np.float32)
        self.unique_graph_to_id = {}                # when using randomly sampled graph, need to map graph to id
        self.unique_graph_index = 0
        self.unique_factor_graph_from_hash = None   # if not none, will save a dictionary of unique graphs for every factor

        # counter for tracking the success rate for reaching goals/graphs
        self.success_tracker = {}
        self.goal_success_tracker = {}
        self.action_stats = {}

        # counters for goal tracking
        self.state_counts = 0
        self.updates_total = 0

        # a positive target_idx is used directly, anything else is shifted by num_factors + 2
        # NOTE(review): a target_idx of exactly 0 also takes the shifted branch -- confirm this is intended
        if self.hindsight_filter.target_idx > 0:
            self.factor_idx = self.hindsight_filter.target_idx
        else:
            self.factor_idx = self.hindsight_filter.target_idx + num_factors + 2

        self.graph_info_size = num_factors + (num_factors + 1)

        # hook handed to buffer.sample in update(): relabel achieved goals only when HER is enabled
        if self.use_her:
            self.her_update_achieved_goal_fn = lambda batch: self.update_achieved_goal(batch)
        else:
            self.her_update_achieved_goal_fn = lambda batch: batch


    def random_sample(self, num):
        """Draw ``num`` samples from ``self.action_space`` and stack them along a new leading axis."""
        samples = []
        for _ in range(num):
            samples.append(self.action_space.sample())
        return np.stack(samples, axis=0)

    def get_target(self, data, next=False):
        """Return ``self.extractor.slice_targets(data, next=next)``; ``next`` is passed through unchanged."""
        targets = self.extractor.slice_targets(data, next=next)
        return targets

    def get_achieved_goal(self, data, use_next_obs):
        """
        Compute the achieved (upper action) goal for HER.

        An achieved goal is computed for every factor with get_achieved_goal_idx and
        written into a zero-initialised array whose shape is the leading dims of
        data.obs.desired_goal followed by (num_factors, goal_dim). The slice at
        self.factor_idx along the factor axis is returned.

        ideally only use data.obs.observation, data.obs.desired_goal, data.graph
        """
        goal_shape = data.obs.desired_goal.shape
        batch_dims, goal_dim = goal_shape[:-1], goal_shape[-1]

        per_factor_goals = np.zeros(batch_dims + (self.num_factors, goal_dim))
        for factor in range(self.num_factors):
            per_factor_goals[..., factor, :] = self.get_achieved_goal_idx(
                data, use_next_obs, factor_choice=factor)

        # keep only the factor this policy is conditioned on
        return per_factor_goals[..., self.factor_idx, :]

    def get_achieved_goal_idx(self, data, use_next_obs, factor_choice=None):
        """
        Return the achieved goal state for one factor.

        The factored state is taken from ``data["next_target"]`` (when
        ``use_next_obs`` is true) or ``data["target"]`` if that entry exists;
        only when it is missing is it recomputed with
        ``self.extractor.slice_targets``. The result is passed to
        ``self.extractor.get_achieved_goal_state`` with ``fidx=factor_choice``.

        :param data: batch-like object supporting ``get(key, default)``.
        :param use_next_obs: select the next-step target instead of the current one.
        :param factor_choice: factor index forwarded as ``fidx``.
        :return: the value returned by ``get_achieved_goal_state``.
        """
        # TODO: does not use the idx because of the way get_achieved_goal is implemented
        key = "next_target" if use_next_obs else "target"
        factored_state = data.get(key, None)                      # (bs, num_factors, longest)
        if factored_state is None:
            # slice lazily: passing the sliced targets as the ``get`` default would
            # evaluate them on every call, even when the cached entry is present
            factored_state = self.extractor.slice_targets(data, next=use_next_obs)
        # TODO: we could do DIAYN style goal conditioning
        return self.extractor.get_achieved_goal_state(factored_state, fidx=factor_choice)   # (bs, longest)


    def policy_obs(self, batch, next=False):
        """
        Convert batch.obs (or batch.obs_next when ``next`` is set) to the
        wrapped policy's observation space by delegating to
        ``self.policy.policy_obs``. Returns the adjusted batch; may be
        overridden in a subclass.
        """
        adjusted = self.policy.policy_obs(batch, next=next)
        return adjusted

    def update_achieved_goal(self, batch):
        """
        Recompute the achieved goal from the next observation and store the same
        value on both batch.obs and batch.obs_next. The batch is modified in
        place and also returned.
        """
        goal = self.get_achieved_goal(batch, use_next_obs=True)
        batch.obs.achieved_goal = goal
        batch.obs_next.achieved_goal = goal
        return batch

    def process_fn(self, batch, buffer, indices, n_step):
        """
        Refresh the stored rewards for the sampled transitions, then delegate to the wrapped policy.

        When ``self.computed_her`` is false, the sampled indices are expanded with
        their successors (via ``buffer.next``) for ``n_step`` steps, and
        ``buffer.rew`` at all of those indices is overwritten with
        ``self.rewtermdone.check_rew`` evaluated on the corresponding transitions.
        The incoming ``batch`` is then replaced by ``buffer[indices]`` (presumably
        so that it carries the rewritten rewards) before calling
        ``self.policy.process_fn``.

        :param batch: the sampled batch; it is re-read from the buffer and not used directly.
        :param buffer: replay buffer supporting indexing, ``next`` and ``rew``.
        :param indices: 1-D array of sampled buffer indices.
        :param n_step: number of consecutive steps whose rewards are refreshed.
        :return: the result of ``self.policy.process_fn(batch, buffer, indices)``.
        """
        if not self.computed_her:
            # TODO: if there is OOM error, we need to change this to sample n_step indices in a for loop
            assert indices.ndim == 1
            num_indices = len(indices)

            # layout: [indices, next(indices), next(next(indices)), ...], n_step chunks in total
            n_step_indices = np.empty(num_indices * n_step, dtype=indices.dtype)
            n_step_indices[:num_indices] = indices
            next_indices = indices
            for i in range(1, n_step):
                next_indices = buffer.next(next_indices)
                n_step_indices[num_indices * i:num_indices * (i + 1)] = next_indices

            n_step_batch = buffer[n_step_indices]
            # achieved_goal and graph is not used anywhere except for RTT, so we don't need to write it back to the buffer
            # TODO: double-check if this is true
            buffer.rew[n_step_indices] = self.rewtermdone.check_rew(n_step_batch)

        batch = buffer[indices]
        return self.policy.process_fn(batch, buffer, indices)

    def post_process_fn(self, batch, buffer, indices):
        """Forward the post-update step to the wrapped policy; nothing is returned."""
        self.policy.post_process_fn(batch, buffer, indices)
        return None
    
    def check_rew_term_trunc(self, data):
        """Return ``self.rewtermdone.check_rew_term_trunc(data)`` unchanged."""
        outcome = self.rewtermdone.check_rew_term_trunc(data)
        return outcome

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
        """Update the wrapped policy with a given batch of data.

        The call is forwarded as-is to ``self.policy.learn`` (e.g. a tianshou
        policy), which implements the actual loss evaluation.

        :return: A dict, including the data needed to be logged (e.g., loss).

        .. note::

            To tell apart the collecting, updating and testing states, check
            ``self.training`` and ``self.updating``; see :ref:`policy_state`.

        .. warning::

            When computing log_prob with ``torch.distributions.Normal`` or
            ``torch.distributions.Categorical``, mind the shapes: Categorical
            yields "[batch_size]" while Normal yields "[batch_size, 1]", and
            auto-broadcasting of torch tensors will amplify such a mismatch.
        """
        log_data = self.policy.learn(batch, **kwargs)
        return log_data

    def update(self, sample_size: int, buffer: Optional[VectorGCReplayBufferManager], **kwargs: Any) -> Dict[str, Any]:
        """Update the wrapped policy network and the replay buffer.

        Logic follows tianshou.policy.base: process_fn, learn, and
        post_process_fn are run in that order. While they run,
        ``self.policy.updating`` is True and ``self.rewtermdone.set_updating(True)``
        is in effect; both are reset afterwards, including when one of the steps
        raises.

        For ``PPOPolicy`` / ``FPGPolicy`` the batch consists of the indices
        returned by ``get_new_indices`` since the previous call, and the
        caller's ``kwargs`` are replaced by ``batch_size`` and ``repeat``.
        Otherwise the batch is drawn with ``buffer.sample`` and ``kwargs`` are
        forwarded unchanged.

        :param int sample_size: passed to ``buffer.sample`` as the sample size, or
            used as ``batch_size`` for the on-policy branch.
        :param buffer: the corresponding replay buffer; ``None`` makes this a no-op.

        :return: A dict, including the data needed to be logged (e.g., loss) from
            ``policy.learn()``. Empty when ``buffer`` is None or when there are no
            new on-policy indices.
        """
        if buffer is None:
            return {}

        if isinstance(self.policy, (PPOPolicy, FPGPolicy)):
            # TODO: any on-policy strategy should be using this entry point
            # since it focuses on sampling the most recent states, but the only on-policy algo
            # is PPO
            self.buffer_last_index, indices = get_new_indices(self.buffer_last_index, buffer)
            if len(indices) == 0:
                return {}
            self.on_policy_indices = indices
            batch = buffer[indices]
            kwargs = {"batch_size": sample_size, "repeat": self.config.policy.ppo.repeat_per_collect}
            n_step = 1
        else:
            # Otherwise, just perform a normal sample operation
            batch, indices = buffer.sample(sample_size,
                                            policy_prio=self.use_prio,
                                            dynamics_prio=False,
                                            her_update_achieved_goal=self.her_update_achieved_goal_fn,
                                            factor_idx=self.factor_idx)
            n_step = self.n_step

        # set flags
        self.policy.updating = True
        self.rewtermdone.set_updating(True)
        self.updates_total += 1

        try:
            # process sampled batch.obs, .act, .obs_next, rew, returns to self.policy space
            # also generally the returns are computed in process_fn
            batch = self.process_fn(batch, buffer, indices, n_step)

            # learn function already has returns and formatted data, so it just implements
            # the loss function evaluation
            result = self.policy.learn(batch, **kwargs)

            # post_process_fn updates the weights with buffer.update_weight
            self.post_process_fn(batch, buffer, indices)
            if self.policy.lr_scheduler is not None:
                self.policy.lr_scheduler.step()
        finally:
            # always clear the flags so a failed update cannot leave the policy
            # and the reward manager stuck in the "updating" state
            self.policy.updating = False
            self.rewtermdone.set_updating(False)
        return result

    def forward(
            self,
            batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            **kwargs: Any,
    ) -> Batch:
        """
        Run the wrapped policy on ``batch`` and return its output with ``act``
        converted to numpy.

        NOTE(review): the ``"lower"`` entry of ``state`` is looked up but neither
        it nor ``kwargs`` is handed to the wrapped policy -- confirm whether the
        state should be forwarded.
        """
        if state is not None:
            state = state["lower"]
        output = self.policy(batch)
        output.act = to_numpy(output.act)
        return output

    def update_history(self, buffer: VectorGCReplayBufferManager):
        """
        Refresh the history of unique graphs and the mask of selectable graphs.

        Keeps a history of the graphs so that we have a sense of what graphs are
        more rare and thus more useful (graphs could also be sampled from it).
        Does nothing unless ``self.history_update_ready`` is set and the buffer is
        non-empty.

        For every factor, each index with a non-zero entry in
        ``buffer.valid_graph_count[factor]`` is mapped to a parent mask through
        ``self.count_idx_to_graph``. Non-trivial masks are encoded as
        ``[factor one-hot, parent mask]``; new encodings are appended to the
        history tables, and every encoding seen in this pass is enabled in
        ``self.choose_from_history_action_mask``.

        :param buffer: replay buffer providing ``valid_graph_count`` and ``len``.
        :raises ValueError: in "choose_from_history" mode when more than
            ``graph_action_n`` unique graphs would have to be stored.
        """
        if not self.history_update_ready or len(buffer) == 0:
            return

        # graph_action_space / add_count_based_lower are not assigned in __init__;
        # read them defensively so a missing attribute does not crash the update
        choose_from_history = getattr(self, "graph_action_space", None) == "choose_from_history"
        add_count_based_lower = getattr(self, "add_count_based_lower", False)

        # we are already assuming factor type graphs
        self.choose_from_history_action_mask[:] = 0

        # add unique graphs to the history
        graph_filter = np.eye(self.num_factors, self.num_factors + 1)
        for factor in range(self.num_factors):
            graph_filter_i = graph_filter[factor]
            factor_onehot = self.factor_onehot_helper[factor]           # (num_factors,)

            for factor_count_idx in np.nonzero(buffer.valid_graph_count[factor])[0]:
                graph = self.count_idx_to_graph[factor_count_idx]       # (num_factors + 1,)

                # filter out trivial graphs (no parents or the object itself as the only parent)
                if not np.any(graph > graph_filter_i):
                    continue

                # prepend the one hot vector for the factor to the parent mask
                parent = graph.astype(int).flatten()                    # ((num_factors + 1) * num_edge_classes,)
                graph = np.concatenate([factor_onehot, parent])         # (num_factors + (num_factors + 1) * num_edge_classes,)

                graph_key = graph.astype(int).tobytes()
                graph = graph.astype(np.float32)

                if graph_key in self.unique_graph_from_hash:
                    # already known: just make it selectable again
                    graph_id = self.unique_graph_to_id[graph_key]
                    self.choose_from_history_action_mask[graph_id] = 1
                    continue

                # check capacity before touching any table so a failure leaves the history consistent
                if choose_from_history and self.unique_graph_index >= self.graph_action_n:
                    raise ValueError("graph_action_n is too small")

                # add new graph to history by adding to the hash->graph dict
                # then advancing the index for the next new graph
                if self.unique_factor_graph_from_hash is not None:
                    self.unique_factor_graph_from_hash[factor][graph_key] = graph
                self.unique_graph_from_hash[graph_key] = graph

                # set graph, hash and unique graph pointers
                self.unique_graph_from_id[self.unique_graph_index] = graph
                self.choose_from_history_action_mask[self.unique_graph_index] = 1
                self.unique_graph_to_id[graph_key] = self.unique_graph_index
                self.unique_graph_index = self.unique_graph_index + 1

                if graph_key not in self.success_tracker:
                    self.success_tracker[graph_key] = 0

        # keep at least one element in the mask to avoid NaN error
        if not torch.any(self.choose_from_history_action_mask):
            self.choose_from_history_action_mask[0] = 1

        if choose_from_history and add_count_based_lower:
            # always keep state count maximization policy available
            self.choose_from_history_action_mask[:self.num_factors] = 1


    def logging(self, writer: SummaryWriter, step: int, wdb_logger=None) -> bool:
        """
        Summarise the per-graph action statistics collected since the last call, then reset them.

        Graphs from the history that are currently enabled in
        ``self.choose_from_history_action_mask`` but were never sampled are added
        with zero counts. For every graph, a readable "parents -> factor" name
        and the sample / reach-graph percentages are computed. The plotting code
        that consumed these values is currently disabled, so ``writer``, ``step``
        and ``wdb_logger`` are unused and the only lasting effect is that
        ``self.action_stats`` is cleared.

        :param writer: tensorboard writer (used only by the disabled plotting code).
        :param step: global step (used only by the disabled plotting code).
        :param wdb_logger: optional wandb logger (used only by the disabled plotting code).
        :return: always True.
        """
        # add graph that has never been sampled
        for k, graph in self.unique_graph_from_hash.items():
            graph_id = self.unique_graph_to_id[k]
            graph_valid = self.choose_from_history_action_mask[graph_id]
            if k not in self.action_stats and graph_valid:
                self.action_stats[k] = [graph, 0., 0., 0.]

        num_graphs = len(self.action_stats) * (self.num_factors)
        total_count = sum(stats[1] for stats in self.action_stats.values())

        graph_names = []
        count_percents = []
        reach_graph_percents = []
        count_rews = []

        factor_names = self.extractor.factor_names + ["act"]
        for graph, count, _reach_goal_count, reach_graph_count in sorted(self.action_stats.values(),
                                                                         key=lambda x: x[1]):
            # guard against an all-zero tally (e.g. only never-sampled graphs are present)
            count_percent = 100 * count / total_count if total_count else 0.0
            reach_graph_percent = 100 * reach_graph_count / total_count if total_count else 0.0

            # NOTE(review): indexing graph[factor] assumes each stored graph is 2-D with rows of
            # length (num_factors + 1) * num_edge_classes; graphs built in update_history are
            # flat vectors -- confirm which layout reaches this method
            for factor in range(graph.shape[0]):
                edges = graph[factor].reshape(self.num_factors + 1, self.num_edge_classes)
                if self.num_edge_classes > 1:
                    parent = edges.argmax(axis=-1).astype(bool)
                else:
                    # with a single edge class the column itself is the 0/1 edge indicator;
                    # argmax over an axis of length 1 would always be 0
                    parent = edges[:, 0].astype(bool)
                graph_name = ", ".join([factor_names[i] for i, p in enumerate(parent)
                                        if p])
                graph_name = graph_name + " -> " + factor_names[factor]
                graph_names.append(graph_name)
                count_percents.append(count_percent)
                reach_graph_percents.append(reach_graph_percent)

        # Plotting is currently disabled; the values computed above feed the code below.
        # fig = plt.figure(figsize=(10, num_graphs * 0.4))

        # # plot upper action frequency
        # ax = plt.gca()
        # y = np.arange(num_graphs)

        # align = 'edge' if count_rews else 'center'
        # rects = ax.barh(y, count_percents, align=align, height=0.4, label="action percentage")
        # ax.bar_label(rects, label_type='edge', fmt="%.1f", padding=3)
        # rects = ax.barh(y, reach_graph_percents, align=align, height=0.4, label="achieve percentage")
        # ax.bar_label(rects, labels=[f"{v:.1f}" if v > 1 else "" for v in rects.datavalues], label_type='center')

        # plt.xlim([0, np.max(count_percents) * 1.1])

        # # plot upper action reward
        # if count_rews:
        #     rects = ax.barh(y, count_rews, align='edge', height=-0.4, label="action reward")
        #     ax.bar_label(rects, label_type='edge', fmt="%.5f", padding=3)

        # ax.set_yticks(y)
        # ax.set_yticklabels(graph_names)

        # plt.legend(loc="lower right")
        # fig.tight_layout()
        # writer.add_figure("action_stats", fig, step)
        # if wdb_logger is not None: wdb_logger.log({"action_stats" + str(step): wandb.Image(fig)})
        # plt.close("all")
        self.action_stats = {}

        return True

    def update_stats(self, metrics, desired_goal):
        """
        Update the goal reaching statistics; called from the collector at every end of episode.

        For each episode the achieved graph (cast to int and hashed by its raw
        bytes) gets its success count increased in ``self.success_tracker`` and
        its ``[graph, times sampled, goals reached, graphs reached]`` entry
        updated in ``self.action_stats``.
        """
        episodes = zip(desired_goal,
                       metrics.success,
                       metrics.reached_goal,
                       metrics.reached_graph,
                       metrics.achieve_graph,
                       metrics.updated)
        for _goal, success, reach_goal, reach_graph, achieve_graph, _lower_updated in episodes:
            graph = achieve_graph.astype(int)
            graph_hash = graph.tobytes()

            # success in this is if a graph is reached, success is tracked per graph.
            self.success_tracker.setdefault(graph_hash, 0)
            self.success_tracker[graph_hash] += int(success)

            # action stats records the graph, the number of times that graph has been
            # sampled, the number of times a goal or graph is reached
            stats = self.action_stats.setdefault(graph_hash, [graph, 0., 0., 0.])
            stats[1] += 1.
            stats[2] += float(reach_goal)
            stats[3] += float(reach_graph)


    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """
        Return the parent class's state dict extended with the graph-history
        bookkeeping attributes, each stored under ``prefix + attribute name``.
        """
        merged = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        extra_attrs = (
            "unique_graph_from_hash",
            "unique_graph_from_id",
            "unique_graph_to_id",
            "unique_graph_index",
            "choose_from_history_action_mask",
            "success_tracker",
        )
        merged.update({prefix + attr_name: getattr(self, attr_name) for attr_name in extra_attrs})
        return merged

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """
        Restore the graph-history bookkeeping attributes saved by ``state_dict``.

        Each known entry found under ``prefix + attribute name`` is removed from
        ``state_dict`` and set on this object; the remaining entries are then
        handed to the parent implementation.
        """
        extra_attrs = (
            "unique_graph_from_hash",
            "unique_graph_from_id",
            "unique_graph_to_id",
            "unique_graph_index",
            "choose_from_history_action_mask",
            "success_tracker",
        )
        for attr_name in extra_attrs:
            key = prefix + attr_name
            if key not in state_dict:
                continue
            setattr(self, attr_name, state_dict.pop(key))

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


    def update_schedules(self):
        """Hook for updating the adaptive reward schedules, if necessary; currently does nothing."""
        return None

    def update_state_counts(self, data):
        """
        Add ``len(data.obs)`` to the running total of seen states.

        right now this only tracks the total number of seen states;
        in the future we could augment this to update the reached state goals
        TODO: reward might need this, so using this function to update reward
              internal state might be important
        """
        num_new_states = len(data.obs)
        self.state_counts = self.state_counts + num_new_states

